from importlib import resources
from typing import Optional, List

import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
from easydict import EasyDict
from matplotlib import rcParams, gridspec, ticker
from matplotlib.dates import DateFormatter
from matplotlib.offsetbox import (OffsetImage, AnnotationBbox)
from rqalpha.mod.rqalpha_mod_sys_analyser.plot.consts import MAX_DD
from rqalpha.mod.rqalpha_mod_sys_analyser.plot.utils import max_dd as _max_dd

from quant.strategy import TradeOp


def _load_signal_icons():
    """Read the marker images bundled with this package, keyed by file stem."""
    icons = {}
    for stem in ('buy', 'sell', 'open', 'close', 'hold', 'skip'):
        with resources.path(__package__, stem + '.png') as icon_path:
            icons[stem] = plt.imread(icon_path)
    return icons


def plot_trade_signals(trade_signals: pd.DataFrame,
                       x_col: str,
                       y_col: str,
                       frequency: str,
                       grid: Optional[List[float]] = None,
                       show_skip: Optional[bool] = False):
    """Plot a series and annotate every trade signal with an icon and a label.

    The caller's ``trade_signals`` frame is left untouched: sorting and the
    minute-level relabelling are done on a copy.

    Args:
        trade_signals: frame holding ``x_col``, ``y_col`` plus ``signal`` and
            ``grid_id`` columns. ``signal`` values are compared against and
            looked up by ``TradeOp`` members.
        x_col: name of the x-axis column. Its values must offer ``strftime``
            when ``frequency == '1m'`` and ``date()`` when
            ``frequency == '1d'``.
        y_col: name of the column plotted on the y axis.
        frequency: ``'1d'`` formats the x axis as dates; ``'1m'`` replaces the
            x values with ``'%m-%d %H:%M'`` strings. For any value other than
            ``'1d'`` the x ticks are hidden.
        grid: optional y values used as y ticks, with horizontal grid lines.
        show_skip: when falsy, ``SKIP``, ``SKIP_BUY`` and ``SKIP_SELL``
            signals are not drawn.

    Returns:
        The matplotlib figure that was created.
    """
    # sort_values() without inplace returns a new frame, so neither the sort
    # nor the column overwrite below leaks back into the caller's data.
    signals = trade_signals.sort_values([x_col], ascending=True)
    fig, ax = plt.subplots(1, figsize=(30, 8))
    if frequency == '1m':
        signals[x_col] = signals[x_col].map(
            lambda ts: ts.strftime('%m-%d %H:%M'))
    ax.plot(signals[x_col], signals[y_col])

    icons = _load_signal_icons()
    img_map = {
        TradeOp.BUY: icons['buy'],
        TradeOp.SELL: icons['sell'],
        TradeOp.OPEN: icons['open'],
        TradeOp.CLOSE: icons['close'],
        TradeOp.SKIP: icons['skip'],
        TradeOp.HOLD: icons['hold'],
        TradeOp.SKIP_BUY: icons['skip'],
        TradeOp.SKIP_SELL: icons['skip'],
    }
    skip_ops = (TradeOp.SKIP, TradeOp.SKIP_BUY, TradeOp.SKIP_SELL)

    for d, p, s, g in zip(signals[x_col], signals[y_col], signals['signal'],
                          signals['grid_id']):
        # HOLD rows are never annotated; skip-type rows only on request.
        if s == TradeOp.HOLD:
            continue
        if not show_skip and s in skip_ops:
            continue

        imagebox = OffsetImage(img_map[s], zoom=0.25)
        imagebox.image.axes = ax
        # SELL icons sit above the data point, every other icon below it.
        ab = AnnotationBbox(imagebox,
                            xy=(d, p),
                            xybox=(0, 50 if s == TradeOp.SELL else -50),
                            xycoords='data',
                            boxcoords="offset points",
                            pad=0.3,
                            arrowprops=dict(arrowstyle="->",
                                            linestyle="dashed"))
        ax.add_artist(ab)

        # Daily data shows only the date part in the text label.
        t = d.date() if frequency == '1d' else d
        ax.annotate('{0}: {1:.2f}@{2}'.format(t, p, g),
                    xy=(d, p),
                    xytext=(0, 0),
                    textcoords="offset points")

    if frequency == '1d':
        fig.autofmt_xdate()
        ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
    else:
        ax.set_xticks([])

    if grid is not None:
        ax.set_yticks(grid)
        ax.grid(visible=True, which='major', axis='y', linestyle='-')
    return fig


def plot_rqalpha_backtest_results(backtest_time,
                                  start_date,
                                  end_date,
                                  results,
                                  savefile=None):
    """Draw the return curves of several backtests plus a summary table.

    Strategies are processed in descending order of
    ``summary['total_returns']``. For each one the curve
    ``portfolio['unit_net_value'] - 1.0`` is plotted and the start and end of
    its maximum drawdown (as computed by rqalpha's ``max_dd`` helper) are
    marked in the curve's colour.

    Args:
        backtest_time: value interpolated into the header line of the figure.
        start_date: value interpolated into the header line of the figure.
        end_date: value interpolated into the header line of the figure.
        results: mapping of strategy name to a result dict. Each result dict
            must provide ``'summary'`` (with the keys ``sharpe``,
            ``max_drawdown``, ``total_returns`` and ``annualized_returns``)
            and ``'portfolio'`` (with a ``unit_net_value`` series whose index
            entries support ``strftime``).
        savefile: optional target passed to ``plt.savefig``; nothing is
            written when it is falsy.

    Returns:
        The matplotlib figure.

    Note:
        Global matplotlib state is modified (``rcParams`` fonts and the
        ``ggplot`` style).
    """
    cjk_fonts = [
        'Microsoft Yahei',
        'Heiti SC',
        'Heiti TC',
        'STHeiti',
        'WenQuanYi Zen Hei',
        'WenQuanYi Micro Hei',
        '文泉驿微米黑',
        'SimHei',
    ]
    rcParams['font.family'] = 'sans-serif'
    # Prepend the CJK-capable fonts, but drop entries already present so the
    # global list does not grow with duplicates on every call.
    rcParams['font.sans-serif'] = cjk_fonts + [
        name for name in rcParams['font.sans-serif'] if name not in cjk_fonts
    ]
    rcParams['axes.unicode_minus'] = False
    title = '策略回测比较'
    plt.style.use('ggplot')
    img_width = 20
    img_height = 12
    fig = plt.figure(title, figsize=(img_width, img_height))
    gs = gridspec.GridSpec(img_height, img_width)
    ax = plt.subplot(gs[2:, :])
    ax.get_xaxis().set_minor_locator(ticker.AutoMinorLocator())
    ax.get_yaxis().set_minor_locator(ticker.AutoMinorLocator())
    # The on/off flag is passed positionally: its keyword was renamed from
    # ``b`` to ``visible`` in matplotlib 3.5 and ``b`` was removed later, so
    # the positional form is the one accepted by every version.
    ax.grid(True, which='minor', linewidth=.2)
    ax.grid(True, which='major', linewidth=1)
    table_data = {}
    keys = pd.DataFrame([{
        'strategy': strategy,
        'returns': result_dict['summary']['total_returns']
    } for strategy, result_dict in results.items()])
    keys = keys.sort_values('returns', ascending=False)
    table_columns = [
        'sharpe', 'max_drawdown', 'total_returns', 'annualized_returns'
    ]
    for strategy in keys.strategy.tolist():
        result_dict = results[strategy]
        summary = result_dict['summary']
        table_data[strategy] = [summary[col] for col in table_columns]
        portfolio = result_dict['portfolio']
        returns = portfolio['unit_net_value'] - 1.0
        p = ax.plot(returns, label=strategy, alpha=1, linewidth=2)
        color = p[0].get_color()
        max_dd = _max_dd(returns.values, portfolio.index)
        # max_dd.start / max_dd.end are used as positions into the index, so
        # the values are fetched with .iloc rather than relying on the
        # deprecated integer fallback of Series.__getitem__.
        for pos in (max_dd.start, max_dd.end):
            ax.plot(returns.index[pos],
                    returns.iloc[pos],
                    MAX_DD.marker,
                    color=color,
                    markersize=MAX_DD.markersize,
                    alpha=MAX_DD.alpha)
        # Extra table column: the date range of the maximum drawdown.
        table_data[strategy].append(
            '%s - %s' % (returns.index[max_dd.start].strftime('%Y-%m-%d'),
                         returns.index[max_dd.end].strftime('%Y-%m-%d')))

    # place legend
    leg = ax.legend(loc='best')
    leg.get_frame().set_alpha(0.5)

    # Show the y axis as percentages. A formatter keeps the labels tied to
    # the tick values instead of freezing a list of label strings.
    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(
            lambda value, _pos: '{:3.2f}%'.format(value * 100)))

    df = pd.DataFrame.from_dict(table_data,
                                orient='index',
                                columns=table_columns +
                                ['max_dd']).reset_index().rename(
                                    columns={'index': 'strategy'},
                                    errors='raise')
    # Series.map is used per column because DataFrame.applymap is deprecated
    # in recent pandas releases.
    for col in ('max_drawdown', 'total_returns', 'annualized_returns'):
        df[col] = df[col].map('{0:.2%}'.format)
    ax2 = plt.subplot(gs[0:2, :])
    ax2.set_title(title)
    ax2.text(
        0, 2,
        '开始时间: %s, 结束时间: %s, 回测时间: %s' % (start_date, end_date, backtest_time))
    ax2.table(cellText=df.values,
              cellLoc='center',
              colLabels=['策略', '夏普比率', '最大回撤', '总收益率', '年化收益率', '最大回撤区间'],
              loc='center')
    ax2.axis('off')

    if savefile:
        plt.savefig(savefile, bbox_inches='tight')
    return fig


def plot_bars(df):
    """Show ``df`` as an SVG-rendered plotly line chart with x-axis gaps hidden.

    Three range breaks are applied to the x axes: two hour-pattern ranges and
    the weekend. The figure is displayed with ``fig.show()``; nothing is
    returned.
    """
    hidden_ranges = [
        # Hour-pattern breaks; presumably the overnight and midday
        # non-trading periods -- confirm against the data source.
        {'bounds': [15.01, 9.5], 'pattern': 'hour'},
        {'bounds': [11.51, 13], 'pattern': 'hour'},
        # hide weekends
        {'bounds': ["sat", "mon"]},
    ]
    chart = px.line(df, render_mode="svg")
    chart.update_xaxes(rangebreaks=hidden_ranges)
    chart.show()